#coding=utf-8
## ######################################################################### ##
##    OpenVFIFE - Open System for Vector Form Intrinsic                      ##
##                Finite Element Method (VFIFE)                              ##
##                GinkGo(Tan Biao)                                           ##
##                                                                           ##
##                                                                           ##
## (C) Copyright 2021, The GinkGo(Tan Biao). All Rights Reserved.            ##
##                                                                           ##
## Commercial use of this program without express permission of              ##
## GinkGo(Tan Biao), is strictly prohibited.  See                            ##
## file 'COPYRIGHT'  in main directory for information on usage and          ##
## redistribution,  and for a DISCLAIMER OF ALL WARRANTIES.                  ##
##                                                                           ##
## Developed by:                                                             ##
##      Tan Biao (ginkgoltd@outlook.com)                                     ##
##                                                                           ##
## ######################################################################### ##

# $Date: 2020-05-17 $
# Written: Tan Biao
# Revised:
#
# Purpose: This file contains the visualization functions for postprocessing.

# The interface:
#

import os
import re
import warnings
from typing import Dict, Generator, Tuple
import numpy as np
import matplotlib.pyplot as plt


###############################################################################
#                           Helper Functions                                  #
###############################################################################
def isnumber(string:str) -> bool:
    """ Determine whether the string is a valid decimal number (optional sign)
    >>> isnumber('0.123')
    True
    >>> isnumber('abc')
    False

    Args:
        string (str): a string

    Returns:
        bool: True or False
    """
    # BUGFIX: the original pattern r'^A|B$' anchored '^' only to the first
    # alternative and '$' only to the second, so partial matches such as
    # '1.2.3' (prefix '1.2') and '--5.0' (char class '[-0-9]') were accepted.
    # A single anchored group checked with fullmatch() fixes both.
    pattern = re.compile(r'[-+]?(\d+\.\d*|\.?\d+)')
    return pattern.fullmatch(string) is not None


def read_csv(fname:str, skiprows=1, usecols=None) -> np.ndarray:
    """Load a comma-separated results file (e.g. coordinates.csv).

    Args:
        fname (str): file name with absolute path
        skiprows (int, optional): number of header rows to skip. Defaults to 1.
        usecols (tuple, optional): column indices to read. Defaults to None,
            i.e. all columns.

    Raises:
        FileNotFoundError: if the file does not exist.

    Returns:
        np.ndarray: file contents, always at least 2-dimensional.
    """
    if not os.path.exists(fname):
        raise FileNotFoundError(fname + " is not found.")

    # np.loadtxt can emit warnings (e.g. on empty files); silence them
    # locally without touching the global warning filters.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        return np.loadtxt(fname, delimiter=",", ndmin=2,
                          skiprows=skiprows, usecols=usecols)


def ndarray_to_dict(arr:np.ndarray) -> Dict:
    """ Convert a 2d ndarray to a dict; the first column is taken as key
    >>> arr1 = np.array([[1, 2, 3, 4]])
    >>> ndarray_to_dict(arr1)
    {1: array([2, 3, 4])}
    >>> arr2 = np.array([[1, 2, 3, 4], [2, 3, 4, 5]])
    >>> ndarray_to_dict(arr2)
    {1: array([2, 3, 4]), 2: array([3, 4, 5])}

    Args:
        arr (np.ndarray): a 2d ndarray; column 0 holds integer-like keys
            (the original docstring wrongly claimed 1d input is accepted)

    Raises:
        ValueError: if arr.size == 0
        ValueError: if arr.ndim != 2

    Returns:
        Dict: {key: val}, key == first column of arr, val == other columns
    """
    if arr.size == 0:
        raise ValueError("the inputted array is empty.")
    if arr.ndim != 2:
        raise ValueError("array.ndim != 2")
    return {int(row[0]): row[1:] for row in arr}


###############################################################################
#                               Results Reader                                #
###############################################################################
class ResultFileSystem(object):
    """ To organize the results directory

    Each time step of the solver is stored in a sub-directory whose name is
    the numeric time value, e.g. "0.01/", "0.02/".

    Attributes:
        dirs (list): numeric sub-directory names, sorted by their time value
        results_type (list): file names found inside the first time step
            (empty list when no time-step directory exists)

    Methods:
        listdir(): yield the absolute path of every time-step directory

    """
    def __init__(self, wd:str) -> None:
        # NOTE: was `assert os.path.exists(wd)` -- asserts vanish under -O,
        # so validate with an explicit exception instead.
        if not os.path.exists(wd):
            raise FileNotFoundError(wd + " not exists!")
        self._wd = wd
        self._subdirs()
        self._results_type()

    def _subdirs(self) -> None:
        # Keep only sub-directories whose name is a number (a time step).
        # Sort by numeric value so iteration follows the time axis --
        # os.listdir() order is arbitrary, and downstream code (TimeHistory)
        # stops at the first step where an id is missing, which only makes
        # sense in time order.
        names = [d for d in os.listdir(self._wd) if isnumber(d)]
        self.dirs = sorted(names, key=float)

    def _results_type(self) -> None:
        # BUGFIX: the original indexed self.dirs[0] unconditionally and
        # raised IndexError for a directory with no time steps.
        if not self.dirs:
            self.results_type = []
            return
        path = os.path.join(self._wd, self.dirs[0])
        self.results_type = os.listdir(path)

    def listdir(self) -> Generator:
        """Yield the absolute path of each time-step directory, in time order."""
        # self.dirs is already filtered by isnumber(); no need to re-check.
        for d in self.dirs:
            yield os.path.join(self._wd, d)


class ResultsReader(object):
    """ Read all results at one time step

    Attributes:
        _wd (str): directory holding the csv files of a single time step

    Methods:
        read_particle_coordinate(): {pid: (x, y, z)}
        read_particle_displace() / read_particle_velocity() /
            read_particle_accelerate(): {pid: motion components}
        read_model_topology(): {eid: (pid1, pid2)}
        read_link_element_force() / read_cable_element_force(): {eid: force}
        read_beam_element_force(): {eid: {node id: force components}}
        collect_particle_info() / collect_element_info(): aggregated dicts

    """
    def __init__(self, wd:str) -> None:
        if not os.path.exists(wd):
            raise FileNotFoundError(wd + " is not found.")
        self._wd = wd

    def _read_dict(self, basename:str, usecols=None) -> Dict:
        # Shared helper: load "<wd>/<basename>" and key its rows by column 0.
        fname = os.path.join(self._wd, basename)
        return ndarray_to_dict(read_csv(fname, usecols=usecols))

    def read_particle_coordinate(self) -> Dict:
        return self._read_dict("coordinates.csv", usecols=(0, 1, 2, 3))

    def read_particle_displace(self) -> Dict:
        return self._read_dict("displace.csv")

    def read_particle_velocity(self) -> Dict:
        return self._read_dict("velocity.csv")

    def read_particle_accelerate(self) -> Dict:
        return self._read_dict("accelerate.csv")

    def _read_particle_motion(self) -> Dict:
        # {"displace": {...}, "velocity": {...}, "accelerate": {...}}
        items = ["displace", "velocity", "accelerate"]
        return {item: self._read_dict(item + ".csv") for item in items}

    def read_model_topology(self) -> Dict:
        # columns: element id, first node id, second node id
        return self._read_dict("elements.csv", usecols=(0, 2, 3))

    def read_link_element_force(self) -> Dict:
        return self._read_dict("link_element_force.csv")

    def read_beam_element_force(self) -> Dict:
        """{eid: {node id: force components}} -- two csv rows per element."""
        fname = os.path.join(self._wd, "beam_element_force.csv")
        force = read_csv(fname)
        ans = {}
        for i in range(force.shape[0] // 2):
            # BUGFIX: the original read the force values from rows i / i+1
            # while taking the keys from rows 2*i / 2*i+1, mixing the forces
            # of different elements for every i > 0.
            tmp = {}
            tmp[int(force[2*i, 1])] = force[2*i, 2:]
            tmp[int(force[2*i+1, 1])] = force[2*i+1, 2:]
            ans[int(force[2*i, 0])] = tmp
        return ans

    def read_cable_element_force(self) -> Dict:
        return self._read_dict("cable_element_force.csv")

    # Backward-compatible alias for the historical misspelled method name.
    read_cable_element_froce = read_cable_element_force

    def collect_particle_info(self) -> Dict:
        coordinates = self.read_particle_coordinate()
        motions = self._read_particle_motion()
        return {'coordinates': coordinates, **motions}

    def collect_element_info(self) -> Dict:
        link = self.read_link_element_force()
        cable = self.read_cable_element_force()
        beam = self.read_beam_element_force()
        return {'link': link, 'cable': cable, 'beam': beam}


class TimeHistory(object):
    """ Extract time history of particle or element

    Attributes:
        _files (ResultFileSystem): the results directory layout

    Methods:
        extract_particle_motion(id, item):
        extract_link_force(id):
        extract_cable_force(id):
        extract_beam_force(id, node):
    """
    def __init__(self, files:ResultFileSystem) -> None:
        super().__init__()
        assert isinstance(files, ResultFileSystem), """files must be
        ResultFileSystem object"""
        self._files = files

    def _extract(self, id:int, fetch) -> np.ndarray:
        """Walk every time step, call fetch(reader) and collect data[id].

        Returns an ndarray whose first column is time, sorted ascending;
        empty (size 0) when the id never appears.
        """
        rows = []
        for t, d in zip(self._files.dirs, self._files.listdir()):
            reader = ResultsReader(d)
            data = fetch(reader)
            # Stop at the first step where the id is missing (e.g. the
            # particle/element was removed during collapse).
            if id not in data: break
            rows.append([float(t)] + data[id].tolist())
        ans = np.array(rows)
        # Guard: argsort-indexing an empty array raised IndexError before.
        if ans.size == 0:
            return ans
        return ans[np.argsort(ans[:, 0]), :]

    def extract_particle_motion(self, id:int, item:str) -> np.ndarray:
        """Time history of one motion quantity of particle *id*.

        Raises:
            ValueError: when item is not displace/velocity/accelerate.
        """
        readers = {
            "displace": ResultsReader.read_particle_displace,
            "velocity": ResultsReader.read_particle_velocity,
            "accelerate": ResultsReader.read_particle_accelerate,
        }
        if item not in readers:
            raise ValueError("Unrecognized item: " + item)
        return self._extract(id, readers[item])

    def extract_link_force(self, id:int) -> np.ndarray:
        return self._extract(id, ResultsReader.read_link_element_force)

    def extract_cable_force(self, id:int) -> np.ndarray:
        # BUGFIX: the original called read_cable_element_force(), which does
        # not exist on ResultsReader (the method is spelled
        # read_cable_element_froce), so this always raised AttributeError.
        return self._extract(id, ResultsReader.read_cable_element_froce)

    def extract_beam_force(self, id:int, node:int) -> np.ndarray:
        # Beam forces are nested ({eid: {node id: force}}), so this cannot
        # reuse _extract()'s flat-dict lookup without changing which
        # KeyErrors can escape; keep the explicit loop.
        rows = []
        for t, d in zip(self._files.dirs, self._files.listdir()):
            reader = ResultsReader(d)
            data = reader.read_beam_element_force()
            if id not in data: break
            rows.append([float(t)] + data[id][node].tolist())
        ans = np.array(rows)
        if ans.size == 0:
            return ans
        return ans[np.argsort(ans[:, 0]), :]



###############################################################################
#                                  VTK Writer                                 #
###############################################################################
class VtkWriter(object):
    """ Write a legacy-format VTK file at every time step (particle data only)

    Attributes:
        _files (ResultFileSystem): results directory layout
        _jobname (str): prefix of the generated .vtk file names
        _version (float): VTK legacy file format version
        _descripts (str): optional description written into the header
        _output_dir (str): directory receiving the .vtk files

    Methods:
        set_descripts(), set_output_dir(), gen_file_header(), write2vtk(),
        generate()

    """
    def __init__(self, files, jobname="jobname", ver=4.0) -> None:
        assert isinstance(files, ResultFileSystem), """job must be
        ResultFileSystem object"""
        self._files = files
        self._jobname = jobname
        self._version = ver
        self._descripts = None
        self._output_dir = None

    def set_descripts(self, descripts:str) -> None:
        self._descripts = descripts

    def set_output_dir(self, dir:str) -> None:
        # The vtk files are grouped in "<dir>/<jobname>/".
        self._output_dir = os.path.join(dir, self._jobname)
        if not os.path.exists(self._output_dir):
            os.makedirs(self._output_dir)

    @staticmethod
    def gen_file_header(filename, descripts=None, version=4.0, dtype="ASCII"):
        """Write the 3-line VTK legacy header, truncating *filename*."""
        assert dtype in ("ASCII", "BINARY"), """VTK only supports ASCII
        and BINARY type"""

        # BUGFIX: the original did `raise Warning(...)`, aborting the whole
        # export on any rerun, although the message says the file "will be
        # rewrited". Emit a real warning and overwrite, as intended.
        if os.path.exists(filename):
            warnings.warn(filename + " exists, and will be rewrited!")

        with open(filename, "w+") as f:
            f.write("# vtk DataFile Version {v}\n".format(v=version))
            if not descripts:
                f.write("Defaults: Nobody knows what's the hell of this!\n")
            else:
                f.write(descripts.upper()+"\n")
            f.write("{dtype}\n".format(dtype=dtype))

    @staticmethod
    def _write_vectors(f, name, data, row_points, num_points):
        # Write one "VECTORS <name> float" section, one line per point,
        # in the same row order as the POINTS section.
        f.write("VECTORS {n} float\n".format(n=name))
        for row in range(num_points):
            vec = data[row_points[row]]
            f.write("{x} {y} {z}\n".format(x=vec[0], y=vec[1], z=vec[2]))
        f.write("\n")

    @staticmethod
    def write2vtk(filename, points, elments, **results):
        """Append an UNSTRUCTURED_GRID dataset to *filename*.

        Args:
            filename (str): vtk file whose header was already written
            points (dict): {pid: (x, y, z)}
            elments (dict): {eid: (pid1, pid2)} -- 2-node line elements
            **results: optional point vectors, keys "displace", "velocity",
                "accelerate", each {pid: 3-component vector}

        Raises:
            ValueError: if the header was not written first.
        """
        assert isinstance(points, dict), "points must be dict"
        assert isinstance(elments, dict), "elements must be dict"

        if not os.path.exists(filename):
            raise ValueError("Please write header first!")

        with open(filename, "a") as f:
            # DATASET type
            f.write("DATASET UNSTRUCTURED_GRID\n")

            # write points; cells reference points by row number, so keep
            # both pid -> row and row -> pid maps
            point_rows, row_points = {}, {}
            num_points = len(points)
            f.write("POINTS {n} float\n".format(n=num_points))
            for row_num, (PID, p) in enumerate(points.items()):
                f.write(" ".join(str(x) for x in p) + "\n")
                point_rows[PID] = row_num
                row_points[row_num] = PID
            f.write("\n")

            # write elements: each cell is "2 <row1> <row2>", so the CELLS
            # size field is 3 integers per element
            num_elments = len(elments)
            f.write("CELLS {n} {s}\n".format(n=num_elments, s=3 * num_elments))
            for PIDs in elments.values():
                f.write("2 {p1} {p2}\n".format(p1=point_rows[PIDs[0]],
                                               p2=point_rows[PIDs[1]]))
            f.write("\n")

            # cell type 3 == VTK_LINE
            f.write("CELL_TYPES {n}\n".format(n=num_elments))
            for _ in range(num_elments):
                f.write("3\n")
            f.write("\n")

            # write point-wise result vectors
            if results:
                f.write("POINT_DATA {n}\n".format(n=num_points))
            for item in ("displace", "velocity", "accelerate"):
                if item in results:
                    VtkWriter._write_vectors(f, item, results[item],
                                             row_points, num_points)

    def generate(self) -> None:
        """Write one <jobname><time>.vtk file per time step."""
        # Fail fast instead of checking inside the loop on every step.
        # NOTE(review): FileExistsError is an odd type for "not configured",
        # kept for backward compatibility with existing callers.
        if self._output_dir is None:
            raise FileExistsError("Please assign output directory.")
        for d in self._files.listdir():
            # read results
            reader = ResultsReader(d)
            coordinates = reader.read_particle_coordinate()
            elems = reader.read_model_topology()
            motion = reader._read_particle_motion()
            # write vtk file; the directory name is the time value
            t = os.path.split(d)[-1]
            fname = os.path.join(self._output_dir, self._jobname + t + ".vtk")
            VtkWriter.gen_file_header(fname, descripts=self._descripts)
            VtkWriter.write2vtk(fname, coordinates, elems, **motion)


if __name__ == "__main__":
    # Demo driver: export one results directory to VTK files.
    wd = "/home/tan/Desktop/test/vfife_test/zc27102_seismic/new/LinkBeamBusrt10"
    # wd = "/home/tan/Desktop/test/vfife_test/zc27102_seismic/LinkBeamSeismic20"
    writer = VtkWriter(ResultFileSystem(wd), "Tower")
    writer.set_descripts("TowerCollapse")
    writer.set_output_dir(wd)
    writer.generate()
